package com.open.controller;

import cn.hutool.core.util.StrUtil;
import cn.hutool.core.util.URLUtil;
import cn.hutool.json.JSONObject;
import com.open.bean.*;
import com.open.component.JwtComponent;
import com.open.result.OpenResult;
import com.open.server.BaiduImageRecognitionServer;
import com.open.server.MsgSocketServer;
import com.open.service.*;
import com.open.util.IPUtil;
import com.open.util.OpenRequestUtil;
import com.open.util.RedisUtil;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.task.TaskExecutor;
import org.springframework.web.bind.annotation.*;

import javax.servlet.http.HttpServletRequest;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.List;

/**
 * @author tanyongpeng
 * <p>des</p>
 **/
@RestController
@RequestMapping("/request/open")
@Slf4j
@AllArgsConstructor
public class WebSocketController {

    private final JwtComponent jwtComponent;

    private final IpWhiteService ipWhiteService;

    private final RequestRecordService requestRecordService;

    private final RequestTokenService requestTokenService;

    private final OpenDictService openDictService;

    private final MsgSocketServer msgSocketServer;

    private final TaskExecutor taskExecutor;

    private final OpenSystemRoleService openSystemRoleService;


    private final FileCacheService fileCacheService;

    @PostMapping("/send/webSocket/msg")
    public OpenResult webSocketMsg(@RequestBody JSONObject jsonObject, HttpServletRequest request){

        OpenDict openDict = openDictService.lambdaQuery()
                .eq(OpenDict::getDictCode, "100003").one();

        if (Integer.parseInt(openDict.getDictValue()) != 0){
            return OpenResult.error(openDict.getDictDes());
        }

        if (StrUtil.isBlank(jsonObject.getStr("modelType"))){
            return OpenResult.error("模型不能为空");
        }

        String modelType = new String(Base64.getDecoder().decode(jsonObject.getStr("modelType")));

        if (!(modelType.equals("gpt-3.5-turbo") || modelType.equals("gpt-3.5-turbo-0301") || modelType.equals("gpt-3.5-turbo-identify") ||
                modelType.equals("file-gpt-3.5"))){
            return OpenResult.error("模型格式不正确");
        }

        if (modelType.equals("file-gpt-3.5")){
            modelType = "gpt-3.5-turbo";
        }

        String imgTag = "";
        if (modelType.equals("gpt-3.5-turbo-identify")){
            String content = URLUtil.decode(jsonObject.getStr("content"), Charset.defaultCharset());
            if (content.indexOf("[[https://plumgpt.com/chat") != content.lastIndexOf("[[https://plumgpt.com/chat")){
                return OpenResult.error("当前文本域只能上传一张图片");
            }
            if (StrUtil.isNotBlank(jsonObject.getStr("imgUrl"))) {
                String recognition = BaiduImageRecognitionServer.getImg(jsonObject.getStr("imgUrl"));
                if (StrUtil.isNotBlank(recognition)){
                    imgTag = recognition;
                }
            }
        }

        String wId = jwtComponent.getUserWId(request);
        if (StrUtil.isBlank(jsonObject.getStr("content")) || StrUtil.isBlank(jsonObject.getStr("code"))) {
            return OpenResult.error("请求参数不能为空");
        }
        IpWhite ipWhite = ipWhiteService.lambdaQuery().eq(IpWhite::getWid, wId).one();
        Integer msgState = RedisUtil.get(wId+"sendMsg");
        if (msgState != null) {
            if (msgState != 0) {
                return OpenResult.error("上次请求结果未返回，请耐心等待或刷新页面，如一直没有返回，则在2分钟后重新发起请求");
            }
        }
        List<RequestToken> requestTokenList = requestTokenService.lambdaQuery().eq(RequestToken::getSignStatus, 0)
                .eq(RequestToken::getTokenType, 6)
                .gt(RequestToken::getBalance, 0).list();
        if (requestTokenList.size() == 0) {
            RedisUtil.set(wId+"sendMsg",0,120);
            return OpenResult.error("当前没有机器可发送消息，请联系管理员");
        }
        if (ipWhite.getRequestCount() <= 0) {
            RedisUtil.set(wId+"sendMsg",0,120);
            return OpenResult.error("你的余额已经用完！如果想继续可加群聊：237343691，联系群主");
        }

        Integer userCount = requestRecordService.lambdaQuery()
                .eq(RequestRecord::getWid, wId)
                .last("AND TO_DAYS(request_time) = TO_DAYS(NOW())").count();

        // 判断今日是否限额
        if (userCount >= ipWhite.getLimitationCount()){
            return OpenResult.error("超出限额使用，具体查看首页 我的账户");
        }
        String code = jsonObject.getStr("code");
        List<FileCache> fileCaches = fileCacheService
                .lambdaQuery()
                .eq(FileCache::getMsgCode,code ).list();
        log.info("当前code值：{}",code);
        String fileText = "";
        if (fileCaches.size() != 0){
            fileText = fileCaches.get(0).getFileText();
        }
        msgSocketServer.msg(wId+"_"+jsonObject.getStr("browserId"),jsonObject.getStr("code"),jsonObject.getStr("content"),
                requestRecordService,requestTokenList,taskExecutor,jsonObject.getStr("system"),openSystemRoleService,modelType,ipWhite,imgTag,fileText);
        return OpenResult.success();
    }

    @GetMapping("/check/balance")
    public OpenResult checkBalance(HttpServletRequest request){
        String userWId = jwtComponent.getUserWId(request);
        IpWhite ipWhite = ipWhiteService.getById(userWId);
        if (ipWhite.getRequestCount() <= 0) {
            return OpenResult.success(false);
        }else {
            return OpenResult.success(true);
        }
    }


    @PostMapping("/msg/save")
    public OpenResult msgSave(@RequestBody JSONObject jsonObject,HttpServletRequest request){
        RequestRecord requestRecord = new RequestRecord();
        String wId = jwtComponent.getUserWId(request);
        IpWhite ipWhite = ipWhiteService.getById(wId);
        requestRecord.setRequestIp(IPUtil.getIpAddress(request));
        requestRecord.setWid(Integer.parseInt(wId));
        requestRecord.setRequestSign(request.getHeader(OpenRequestUtil.REQUEST_HEADER_SING));
        requestRecord.setRequestTime(new Date());
        requestRecord.setRequestContent(URLUtil.decode(jsonObject.getStr("requestMsg"), Charset.defaultCharset()));
        requestRecord.setRespondContent(URLUtil.decode(jsonObject.getStr("responseMsg"), Charset.defaultCharset()));
        requestRecord.setGroupCode(Long.parseLong(jsonObject.getStr("code")));
        requestRecord.setRequestStatus(1);
        requestRecord.setRespondTime(new Date());
        requestRecord.setRequestMsg("请求成功");
        requestRecord.setNikeName(ipWhite.getNikeName());
        requestRecord.setUserType(ipWhite.getUserType());
        requestRecord.setMsgType(jsonObject.getStr("msgType"));
        requestRecord.setMsgResources(jsonObject.getStr("msgResources"));
        requestRecord.setModelType(new String(Base64.getDecoder().decode(jsonObject.getStr("modelType"))));
        if (StrUtil.isNotBlank(jsonObject.getStr("system"))){
            OpenSystemRole system = openSystemRoleService.getById(jsonObject.getStr("system"));
            requestRecord.setSystemName(system.getRoleName());
            requestRecord.setSystemValue(system.getRoleValue());
            requestRecord.setSystemRid(system.getRid());
        }
        int balance = getBalance(ipWhite);
        requestRecord.setConsumptionLimit(balance);
        requestRecord.setDialogueCount(ipWhite.getDialogueCount());
        requestRecord.setTemperature(ipWhite.getTemperature());
        ipWhiteService.lambdaUpdate()
                .set(IpWhite::getRequestCount,ipWhite.getRequestCount() - balance)
                .eq(IpWhite::getWid,ipWhite.getWid()).update();
        boolean save = requestRecordService.save(requestRecord);
        return OpenResult.save(save);
    }

    public int getBalance(IpWhite ipWhite){
        if (ipWhite.getDialogueCount() <= 3){
            return 1;
        }else if (ipWhite.getDialogueCount() >=4 && ipWhite.getDialogueCount() <= 6){
            return 3;
        }else if (ipWhite.getDialogueCount() >=7 && ipWhite.getDialogueCount() <= 9){
            return 5;
        }else {
            return 8;
        }
    }
}
